{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# COVID-19 Drug Repurposing via disease-compound relations\n",
    "This example shows how to do drug repurposing with DRKG using only the pretrained model."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Collecting COVID-19 related diseases\n",
    "First, we collect the list of coronavirus (COV) diseases in DRKG. Each one is identified by the Disease ID that DRKG uses to encode it. Here we take all of the COV diseases as targets."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "COV_disease_list = [\n",
    "'Disease::SARS-CoV2 E',\n",
    "'Disease::SARS-CoV2 M',\n",
    "'Disease::SARS-CoV2 N',\n",
    "'Disease::SARS-CoV2 Spike',\n",
    "'Disease::SARS-CoV2 nsp1',\n",
    "'Disease::SARS-CoV2 nsp10',\n",
    "'Disease::SARS-CoV2 nsp11',\n",
    "'Disease::SARS-CoV2 nsp12',\n",
    "'Disease::SARS-CoV2 nsp13',\n",
    "'Disease::SARS-CoV2 nsp14',\n",
    "'Disease::SARS-CoV2 nsp15',\n",
    "'Disease::SARS-CoV2 nsp2',\n",
    "'Disease::SARS-CoV2 nsp4',\n",
    "'Disease::SARS-CoV2 nsp5',\n",
    "'Disease::SARS-CoV2 nsp5_C145A',\n",
    "'Disease::SARS-CoV2 nsp6',\n",
    "'Disease::SARS-CoV2 nsp7',\n",
    "'Disease::SARS-CoV2 nsp8',\n",
    "'Disease::SARS-CoV2 nsp9',\n",
    "'Disease::SARS-CoV2 orf10',\n",
    "'Disease::SARS-CoV2 orf3a',\n",
    "'Disease::SARS-CoV2 orf3b',\n",
    "'Disease::SARS-CoV2 orf6',\n",
    "'Disease::SARS-CoV2 orf7a',\n",
    "'Disease::SARS-CoV2 orf8',\n",
    "'Disease::SARS-CoV2 orf9b',\n",
    "'Disease::SARS-CoV2 orf9c',\n",
    "'Disease::MESH:D045169',\n",
    "'Disease::MESH:D045473',\n",
    "'Disease::MESH:D001351',\n",
    "'Disease::MESH:D065207',\n",
    "'Disease::MESH:D028941',\n",
    "'Disease::MESH:D058957',\n",
    "'Disease::MESH:D006517'\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Candidate drugs\n",
    "We use FDA-approved drugs in DrugBank as candidate drugs, excluding drugs with molecular weight < 250. The drug list is in infer\\_drug.tsv."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import csv\n",
    "\n",
    "# Load the candidate drug list\n",
    "drug_list = []\n",
    "with open(\"./infer_drug.tsv\", newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['drug','ids'])\n",
    "    for row_val in reader:\n",
    "        drug_list.append(row_val['drug'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8104"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(drug_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Treatment relation"
   ]
  },
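  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "We use the two treatment relations in DRKG that link a compound to a disease: `Hetionet::CtD::Compound:Disease` and `GNBR::T::Compound:Disease`."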
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "treatment = ['Hetionet::CtD::Compound:Disease','GNBR::T::Compound:Disease']"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get pretrained model\n",
    "We can directly use the pretrained model to do drug repurposing."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Download finished. Unzipping the file...\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "import numpy as np\n",
    "import sys\n",
    "sys.path.insert(1, '../utils')\n",
    "from utils import download_and_extract\n",
    "download_and_extract()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "entity_idmap_file = '../data/drkg/embed/entities.tsv'\n",
    "relation_idmap_file = '../data/drkg/embed/relations.tsv'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Get embeddings for diseases and drugs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Get drugname/disease name to entity ID mappings\n",
    "entity_map = {}\n",
    "entity_id_map = {}\n",
    "relation_map = {}\n",
    "with open(entity_idmap_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['name','id'])\n",
    "    for row_val in reader:\n",
    "        entity_map[row_val['name']] = int(row_val['id'])\n",
    "        entity_id_map[int(row_val['id'])] = row_val['name']\n",
    "        \n",
    "with open(relation_idmap_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['name','id'])\n",
    "    for row_val in reader:\n",
    "        relation_map[row_val['name']] = int(row_val['id'])\n",
    "        \n",
    "# handle the ID mapping\n",
    "drug_ids = []\n",
    "disease_ids = []\n",
    "for drug in drug_list:\n",
    "    drug_ids.append(entity_map[drug])\n",
    "    \n",
    "for disease in COV_disease_list:\n",
    "    disease_ids.append(entity_map[disease])\n",
    "\n",
    "treatment_rid = [relation_map[treat]  for treat in treatment]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load embeddings\n",
    "import torch as th\n",
    "entity_emb = np.load('../data/drkg/embed/DRKG_TransE_l2_entity.npy')\n",
    "rel_emb = np.load('../data/drkg/embed/DRKG_TransE_l2_relation.npy')\n",
    "\n",
    "drug_ids = th.tensor(drug_ids).long()\n",
    "disease_ids = th.tensor(disease_ids).long()\n",
    "treatment_rid = th.tensor(treatment_rid)\n",
    "\n",
    "drug_emb = th.tensor(entity_emb[drug_ids])\n",
    "treatment_embs = [th.tensor(rel_emb[rid]) for rid in treatment_rid]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Drug Repurposing Based on Edge Score\n",
    "We use the following formula to calculate the edge score. Note that we apply logsigmoid, which makes all scores < 0. The larger the score, the more strongly the model predicts that $h$ has relation $r$ with $t$.\n",
    "\n",
    "$\\mathbf{d} = \\gamma - ||\\mathbf{h}+\\mathbf{r}-\\mathbf{t}||_{2}$\n",
    "\n",
    "$\\mathbf{score} = \\log\\left(\\frac{1}{1+\\exp(-\\mathbf{d})}\\right)$\n",
    "\n",
    "When doing drug repurposing, we only use the treatment-related relations."
   ]
  },
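  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on the formula above, here is a worked example with $\\gamma = 12$ (the value set in the next code cell) and three hypothetical distances $||\\mathbf{h}+\\mathbf{r}-\\mathbf{t}||_{2}$ of 10, 12, and 14. These distances are illustrative assumptions, not values taken from DRKG.\n",
    "\n",
    "$\\mathbf{d} = 12 - 10 = 2, \\quad \\mathbf{score} = \\log\\left(\\frac{1}{1+\\exp(-2)}\\right) \\approx -0.127$\n",
    "\n",
    "$\\mathbf{d} = 12 - 12 = 0, \\quad \\mathbf{score} = \\log\\left(\\frac{1}{2}\\right) \\approx -0.693$\n",
    "\n",
    "$\\mathbf{d} = 12 - 14 = -2, \\quad \\mathbf{score} = \\log\\left(\\frac{1}{1+\\exp(2)}\\right) \\approx -2.127$\n",
    "\n",
    "A smaller distance gives a score closer to 0, so the best candidates are the ones with the least negative scores."
   ]
  },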
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as fn\n",
    "\n",
    "gamma=12.0\n",
    "def transE_l2(head, rel, tail):\n",
    "    score = head + rel - tail\n",
    "    return gamma - th.norm(score, p=2, dim=-1)\n",
    "\n",
    "scores_per_disease = []\n",
    "dids = []\n",
    "for rid in range(len(treatment_embs)):\n",
    "    treatment_emb=treatment_embs[rid]\n",
    "    for disease_id in disease_ids:\n",
    "        disease_emb = entity_emb[disease_id]\n",
    "        score = fn.logsigmoid(transE_l2(drug_emb, treatment_emb, disease_emb))\n",
    "        scores_per_disease.append(score)\n",
    "        dids.append(drug_ids)\n",
    "scores = th.cat(scores_per_disease)\n",
    "dids = th.cat(dids)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sort scores in decending order\n",
    "idx = th.flip(th.argsort(scores), dims=[0])\n",
    "scores = scores[idx].numpy()\n",
    "dids = dids[idx].numpy()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Now we output proposed treatments"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "_, unique_indices = np.unique(dids, return_index=True)\n",
    "topk=100\n",
    "topk_indices = np.sort(unique_indices)[:topk]\n",
    "proposed_dids = dids[topk_indices]\n",
    "proposed_scores = scores[topk_indices]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Now we list the pairs of in form of (drug, treat, disease, score) \n",
    "\n",
    "We select top K relevent drugs according the edge score"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compound::DB00811\t-0.21416784822940826\n",
      "Compound::DB00993\t-0.8350892663002014\n",
      "Compound::DB00635\t-0.8974801898002625\n",
      "Compound::DB01082\t-0.9854875802993774\n",
      "Compound::DB01234\t-0.9984006881713867\n",
      "Compound::DB00982\t-1.0160722732543945\n",
      "Compound::DB00563\t-1.0189464092254639\n",
      "Compound::DB00290\t-1.064104437828064\n",
      "Compound::DB01394\t-1.080674648284912\n",
      "Compound::DB01222\t-1.084547519683838\n",
      "Compound::DB00415\t-1.0853980779647827\n",
      "Compound::DB01004\t-1.096668004989624\n",
      "Compound::DB00860\t-1.1004775762557983\n",
      "Compound::DB00681\t-1.1011559963226318\n",
      "Compound::DB00688\t-1.125687599182129\n",
      "Compound::DB00624\t-1.1428285837173462\n",
      "Compound::DB00959\t-1.1618402004241943\n",
      "Compound::DB00115\t-1.1868144273757935\n",
      "Compound::DB00091\t-1.1906721591949463\n",
      "Compound::DB01024\t-1.2051165103912354\n",
      "Compound::DB00741\t-1.2147064208984375\n",
      "Compound::DB00441\t-1.2320444583892822\n",
      "Compound::DB00158\t-1.2346539497375488\n",
      "Compound::DB00499\t-1.2525147199630737\n",
      "Compound::DB00929\t-1.2730510234832764\n",
      "Compound::DB00770\t-1.2825534343719482\n",
      "Compound::DB01331\t-1.2960500717163086\n",
      "Compound::DB00958\t-1.2967796325683594\n",
      "Compound::DB02527\t-1.303438663482666\n",
      "Compound::DB00196\t-1.3053392171859741\n",
      "Compound::DB00537\t-1.3131829500198364\n",
      "Compound::DB00644\t-1.3131871223449707\n",
      "Compound::DB01048\t-1.3267226219177246\n",
      "Compound::DB00552\t-1.3272088766098022\n",
      "Compound::DB00328\t-1.3286101818084717\n",
      "Compound::DB00171\t-1.3300385475158691\n",
      "Compound::DB01212\t-1.3330755233764648\n",
      "Compound::DB09093\t-1.3382999897003174\n",
      "Compound::DB00783\t-1.338560938835144\n",
      "Compound::DB09341\t-1.3396968841552734\n",
      "Compound::DB00558\t-1.3425884246826172\n",
      "Compound::DB05382\t-1.3575129508972168\n",
      "Compound::DB01112\t-1.3584508895874023\n",
      "Compound::DB00515\t-1.3608112335205078\n",
      "Compound::DB01101\t-1.381548523902893\n",
      "Compound::DB01165\t-1.3838160037994385\n",
      "Compound::DB01183\t-1.3862146139144897\n",
      "Compound::DB00815\t-1.3863483667373657\n",
      "Compound::DB00755\t-1.3881785869598389\n",
      "Compound::DB00198\t-1.3885014057159424\n",
      "Compound::DB00480\t-1.3935325145721436\n",
      "Compound::DB00806\t-1.3996552228927612\n",
      "Compound::DB01656\t-1.3999741077423096\n",
      "Compound::DB00759\t-1.404650092124939\n",
      "Compound::DB00917\t-1.4116020202636719\n",
      "Compound::DB01181\t-1.4148889780044556\n",
      "Compound::DB01039\t-1.4176580905914307\n",
      "Compound::DB00512\t-1.4207379817962646\n",
      "Compound::DB01233\t-1.4211887121200562\n",
      "Compound::DB11996\t-1.425789475440979\n",
      "Compound::DB00738\t-1.4274098873138428\n",
      "Compound::DB00716\t-1.4327492713928223\n",
      "Compound::DB03461\t-1.437927484512329\n",
      "Compound::DB00591\t-1.4404338598251343\n",
      "Compound::DB01327\t-1.4408743381500244\n",
      "Compound::DB00131\t-1.4446886777877808\n",
      "Compound::DB00693\t-1.4460749626159668\n",
      "Compound::DB00369\t-1.4505752325057983\n",
      "Compound::DB04630\t-1.453115463256836\n",
      "Compound::DB00878\t-1.456466555595398\n",
      "Compound::DB08818\t-1.4633680582046509\n",
      "Compound::DB00682\t-1.4691765308380127\n",
      "Compound::DB01068\t-1.4700121879577637\n",
      "Compound::DB00446\t-1.4720206260681152\n",
      "Compound::DB01115\t-1.4729849100112915\n",
      "Compound::DB00355\t-1.4770021438598633\n",
      "Compound::DB01030\t-1.485068678855896\n",
      "Compound::DB00620\t-1.4973516464233398\n",
      "Compound::DB00396\t-1.4976921081542969\n",
      "Compound::DB01073\t-1.4987037181854248\n",
      "Compound::DB00640\t-1.5026229619979858\n",
      "Compound::DB00999\t-1.5034282207489014\n",
      "Compound::DB01060\t-1.504364252090454\n",
      "Compound::DB00493\t-1.5072362422943115\n",
      "Compound::DB01240\t-1.5090957880020142\n",
      "Compound::DB00364\t-1.509944200515747\n",
      "Compound::DB01263\t-1.511993169784546\n",
      "Compound::DB00746\t-1.513066053390503\n",
      "Compound::DB00718\t-1.5183149576187134\n",
      "Compound::DB01065\t-1.5207160711288452\n",
      "Compound::DB01205\t-1.521277904510498\n",
      "Compound::DB01137\t-1.5229592323303223\n",
      "Compound::DB08894\t-1.5239660739898682\n",
      "Compound::DB00813\t-1.5308701992034912\n",
      "Compound::DB01157\t-1.5316557884216309\n",
      "Compound::DB04570\t-1.5430843830108643\n",
      "Compound::DB00459\t-1.5503207445144653\n",
      "Compound::DB01752\t-1.5541703701019287\n",
      "Compound::DB00775\t-1.5559712648391724\n",
      "Compound::DB01610\t-1.5563474893569946\n"
     ]
    }
   ],
   "source": [
    "for i in range(topk):\n",
    "    drug = int(proposed_dids[i])\n",
    "    score = proposed_scores[i]\n",
    "    \n",
    "    print(\"{}\\t{}\".format(entity_id_map[drug], score))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Check Clinical Trial Drugs\n",
    "Six clinical trial drugs appear in the top 100. (Note: Ribavirin exists in DRKG as a treatment for SARS.)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[0]\tRibavirin\t-0.21416784822940826\n",
      "[4]\tDexamethasone\t-0.9984006881713867\n",
      "[8]\tColchicine\t-1.080674648284912\n",
      "[16]\tMethylprednisolone\t-1.1618402004241943\n",
      "[49]\tOseltamivir\t-1.3885014057159424\n",
      "[87]\tDeferoxamine\t-1.513066053390503\n"
     ]
    }
   ],
   "source": [
    "clinical_drugs_file = './COVID19_clinical_trial_drugs.tsv'\n",
    "clinical_drug_map = {}\n",
    "with open(clinical_drugs_file, newline='', encoding='utf-8') as csvfile:\n",
    "    reader = csv.DictReader(csvfile, delimiter='\\t', fieldnames=['id', 'drug_name','drug_id'])\n",
    "    for row_val in reader:\n",
    "        clinical_drug_map[row_val['drug_id']] = row_val['drug_name']\n",
    "        \n",
    "for i in range(topk):\n",
    "    # Strip the 'Compound::' prefix to get the 7-character DrugBank ID\n",
    "    drug = entity_id_map[int(proposed_dids[i])][10:17]\n",
    "    if drug in clinical_drug_map:\n",
    "        score = proposed_scores[i]\n",
    "        print(\"[{}]\\t{}\\t{}\".format(i, clinical_drug_map[drug], score))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "32"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(clinical_drug_map)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.6"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
